package LiuYun

import spinal.core._
import spinal.lib._
/** Geometry / feature switches for the data cache.
  *
  *  - w_idx: width of the set index (DCache addresses its tag RAM with 8 bits)
  *  - w_pos: width of the bank-select field inside a line
  *  - n_way: associativity
  *  - withBitMask / withDirty: presumably enable per-byte write masks and a
  *    dirty bit in the RAMs built by CacheConfig — confirm against CacheConfig.
  */
object DCacheConfig extends CacheConfig {
    val w_idx: Int         = 8
    val w_pos: Int         = 2
    val n_way: Int         = 2
    val withBitMask: Boolean = true
    val withDirty: Boolean   = true
}
/** Bus-side miss channel of the data cache.
  *
  * Extends the IMiss read channel with a read size, a write path
  * (awvalid/awaddr/awsize/wdata/wstrb) used for write-backs and uncached
  * stores, an `occupy` flag and the write completion `bvalid`.
  */
class DMiss extends IMiss {
    val occupy : Bool      = Bool()
    val arsize : UInt      = UInt(3 bits)
    val awvalid: Bool      = Bool()
    val awaddr : UInt      = LISA.PA
    val awsize : UInt      = UInt(3 bits)
    val wdata  : Vec[UInt] = Vec(UInt(32 bits), 4)
    val wstrb  : Bits      = Bits(4 bits)
    val bvalid : Bool      = Bool()

    override def asMaster(): Unit = {
        super.asMaster()
        // cache-driven status and read size
        out(occupy, arsize)
        // cache-driven write request (address + data + strobes)
        out(awvalid, awaddr, awsize, wdata, wstrb)
        // write completion returned by the bus side
        in(bvalid)
    }
}
/** Memory-operation descriptor.
  *
  * st / ld / pr are operation-kind flags (pr presumably marks a preload — see
  * how DCacheCtrl derives `skip` from it); w / h / b are access-width flags.
  */
class MemType extends Bundle {
    val st = Bool()
    val ld = Bool()
    val pr = Bool()
    val w  = Bool()
    val h  = Bool()
    val b  = Bool()

    /** 3-bit size code `0 ## w ## h`: word -> 2, half -> 1, byte (neither) -> 0. */
    def size: UInt = (w ## h).asUInt.resize(3)
}
/** Per-request lookup state of the data cache.
  *
  * On top of LookupInfo it carries the translation result (pt, mat, ex,
  * ecode, esubc, all combinational) and registered copies of the accepted
  * request (vidx, mem, wdata, wstrb, cacop).
  */
class DLookupInfo extends LookupInfo {
    val cfg  = DCacheConfig
    val pt    : UInt    = UInt(LISA.w_ppn bits)
    val mat   : UInt    = UInt(2 bits)
    val ex    : Bool    = Bool()
    val ecode : UInt    = LISA.Ex.Code
    val esubc : UInt    = LISA.Ex.Subc
    // a memory-access type of 1 is treated as cacheable
    val cached: Bool    = mat === U(1)
    val vidx  : UInt    = Reg(UInt(12 bits))
    val mem   : MemType = Reg(new MemType)
    val wdata : UInt    = Reg(LISA.GPR)
    val wstrb : Bits    = Reg(Bits(4 bits))
    val cacop : Cacop   = Reg(new Cacop)

    /** Drive the miss-channel read address and whole-line flag.
      *
      * @param miss   miss channel to drive
      * @param is_new true when a fresh TLB result is available this cycle; then
      *               `pa` / `cached` are used, otherwise the saved copy in `res`
      *               (saved page tag concatenated with `vidx`)
      * @param res    registered copy of the last TLB result
      *
      * An index-invalidate CACOP always requests a whole line.
      */
    def setFields(miss: DMiss, is_new: Bool, res: SavedTlbres): Unit = {
        miss.araddr := Mux(is_new, pa, res.pt @@ vidx)
        miss.line   := Mux(cacop.isIdxInv && cacop.valid, True, Mux(is_new, cached, res.cc))
    }

    /** Copy the exception code and sub-code onto the exception port. */
    def send_ex(e: DExcept): Unit = {
        e.ecode := ecode
        e.esubc := esubc
    }
}

/** Registered copy of the most recent TLB result.
  *
  * DCache falls back to these registers in cycles where no fresh L1/L2 TLB
  * response is present.
  */
class SavedTlbres extends Bundle {
    val pt   : UInt = Reg(UInt(LISA.w_ppn bits))
    val mat  : UInt = Reg(UInt(2 bits))
    // cacheability, same rule as the live lookup: mat == 1
    val cc   : Bool = mat === U(1)
    val ex   : Bool = Reg(Bool())
    val ecode: UInt = Reg(LISA.Ex.Code)
    val esubc: UInt = Reg(LISA.Ex.Subc)

    /** A constant 1 bit (the valid flag, judging by the name) above the saved page tag. */
    def tagv: UInt = U(1, 1 bits) @@ pt
}

/** Pipeline-to-cache request.
  *
  * Extends CacheReq with the memory-op descriptor, store data and byte
  * strobes, the CACOP fields, and `translate_ok`, which flows back from the
  * cache to the requester.
  */
class DataReq extends CacheReq {
    val mem         : MemType = new MemType
    val wdata       : UInt    = LISA.GPR
    val wstrb       : Bits    = Bits(4 bits)
    val cacop       : Cacop   = new Cacop
    val translate_ok: Bool    = Bool()

    /** Size code of the access (see MemType.size). */
    def getSize: UInt = mem.size

    override def asMaster(): Unit = {
        super.asMaster()
        // payload travels from the requester to the cache
        out(mem, wdata, wstrb, cacop)
        // the cache answers whether translation has settled
        in(translate_ok)
    }
}
/** Data-cache response channel; adds nothing to the generic CacheResp. */
class DataResp extends CacheResp {}
/** Control path of the data cache.
  *
  * Holds the state register `s`. Every cycle it derives:
  *  - the request handshake (`addr_ok`, `accept`);
  *  - the RAM strobes (`read_tag`, `read_tlb`, `st_0_tag`, `read_dat`,
  *    `update`, `wrt_back`, `inv_wrbk`);
  *  - the miss / write-back handshake on `miss`;
  *  - the L2-TLB request (`l2req`, `l2sent`);
  *  - CACOP sequencing (`cacop_working`, `send_cop`).
  *
  * The `Bool()` fields listed under "input fields" and the other undriven
  * fields below (l1tlb_resp, l2tlb_*, dirty, icop_rdy) are driven by the owner
  * directly or through the `collect` overloads.
  *
  * @param miss miss / write-back channel towards the bus
  * @param resp response channel towards the pipeline (valid/allow/cancel)
  * @param recv bus receiver; its valid/done feed the response and its
  *             `start` is raised here for loads sent to the bus
  */
class DCacheCtrl(miss:DMiss,resp:DataResp,recv:BusReceiver) extends Area{
    /** FSM states. sDied is entered from sLookup once an exception has been
      * handed off (ex && excp_ready) and is only left via init / resp.cancel. */
    object CacheState extends SpinalEnum{
        val sIdle,sLookup,sWaitL2,sWaitArready,sWrite,sMiss,sOperate,sWaitIcopRdy,sDied = newElement()
    }
    import CacheState._
    // input fields
    val req_valid  = Bool()
    val excp_ready = Bool()
    val fill_done  = Bool()
    val cached     = Bool()
    val ex         = Bool()
    val skip       = Bool()
    val hit        = Bool()

    // RamInit(8): presumably sweeps the 8-bit set index at reset (init.idx
    // drives the tag RAM address in DCache) — confirm against RamInit
    val init = RamInit(8)
    val s  = RegInit(sIdle)

    val addr_ok = Bool()
    val miss_valid = Bool()
    // a miss / uncached request is handed to the bus this cycle
    val send_miss  = miss_valid && miss.arready

    val mem        = new MemType
    val req_cacop  = new Cacop
    val cacop      = new Cacop
    val l1tlb_resp = new AddrTransResp
    val l2tlb_valid = Bool()
    val l2tlb_willresp = Bool()
    val l2tlb_ready = Bool()
    // NOTE(review): ex3_allow is not referenced anywhere in this file.
    val ex3_allow = Bool() //TODO: not the best way
    //only when ex3 is empty(mainly wait for div finish), l2 req will send
    // 1st stage
    val accept = Bool()
    val req_st_0_tag = req_cacop.valid && req_cacop.isSt0Tag && req_cacop.isDCache
    // && binds tighter than ||: !valid || (isDCache && !isSt0Tag)
    val req_read_tag = !req_cacop.valid || req_cacop.isDCache && !req_cacop.isSt0Tag
    val req_read_tlb = !req_cacop.valid || req_cacop.isHitInv
    val read_tlb = Bool()
    val read_tag = Bool()
    val st_0_tag = Bool()
    // 2nd stage
    val read_dat = Bool()
    val update   = Bool()
    val inv_wrbk = Bool()
    val wrt_back = Bool()
    val send_cop = Bool()
    val icop_rdy = Bool()
    // 3rd stage
    val replace  = Reg(Bool())
    val dirty    = Bool()
    val cacop_working = Reg(Bool())
    cacop_working.riseWhen(req_cacop.valid)

    // CACOP decode of the registered operation
    val is_my_op   = cacop_working && cacop.isDCache
    val is_hit_inv = cacop_working && cacop.isHitInv
    val my_hit_inv = cacop.isDCache && is_hit_inv
    val my_idx_inv = is_my_op && cacop.isIdxInv
    // NOTE(review): the three do_* signals are not referenced in this file.
    val do_my_op   = is_my_op   && miss.arready
    val do_hit_inv = my_hit_inv && miss.arready && !ex
    val do_idx_inv = my_idx_inv && miss.arready

    val data_ok = Reg(Bool())
    val send_ex = (s === sLookup && (mem.st || mem.ld) || is_hit_inv) && ex
    val find_ok = s === sLookup && hit && cached && !ex
    val miss_ok = s === sMiss && fill_done && recv.done

    // L2-TLB request bookkeeping: l2req is the registered request valid,
    // l2sent records that the L2 TLB has taken the request.
    val l2sent = RegInit(False)
    val wait_l2 = Bool()
    wait_l2 := False
    val l2req = RegInit(False)
    /** Raise the L2 request and force the next state to sWaitL2 (via wait_l2). */
    def sendL2req():Unit = {
        l2req := True//Reg
        wait_l2 := True //Bool
        l2sent := l2tlb_ready
    }
    // cached hit: stores write the data RAM, loads read it
    update   := find_ok && mem.st
    read_dat := find_ok && mem.ld
    when(s === sOperate){
        assert(cacop.valid)
    }

    addr_ok := False
    miss_valid := False
    // data_ok covers responses not carried by recv (hits, stores); it drops
    // once the pipeline takes the response or the FSM has died.
    when(read_dat || update || miss_ok && mem.st || s === sWrite && miss.bvalid){
        data_ok := True
    }.elsewhen(resp.allow || s === sDied){
        data_ok := False
    }
    resp.valid := data_ok || recv.valid

    send_cop := cacop_working && cacop.isICache && !is_hit_inv

    // Per-state outputs. The s := assignments here can still be overridden by
    // the priority chain further below (last assignment wins).
    switch(s){
        is(sIdle){
            addr_ok  := True
            l2sent := False
            cacop_working := req_cacop.valid
        }
        is(sLookup){
            //when(l1tlb_resp.hit || ex || l2tlb_valid){
            l2sent := False
            when(skip){
                addr_ok  := True
            }.elsewhen(ex){
                // exception: wait here until excp_ready (see priority chain)
            }.elsewhen(!l1tlb_resp.hit && !l2tlb_valid){
                sendL2req()
            }.elsewhen(hit && cached){
                addr_ok  := True
            }.otherwise{
                miss_valid := True
            }
        }
        is(sWaitL2){
            when(l2tlb_ready) {
                l2req := False
                l2sent := True
            }.elsewhen(!l2sent){
                l2req := True
                //l2sent := l2tlb_ready
            }
        }
        is(sWrite){
            when(miss.bvalid){
                addr_ok  := True
            }
        }
        is(sMiss){
            when(miss_ok){
                addr_ok  := True
            }
        }
        is(sWaitArready){
            when(cacop.valid && cacop.isDCache){
                when(miss.arready){
                    cacop_working := False
                    s := sIdle
                }
            }.otherwise{
                miss_valid := True
            }
        }
        is(sWaitIcopRdy){
            send_cop := True
            when(icop_rdy){
                cacop_working := False
                s := sIdle
            }
        }
    }

    //send_cop := s === sWaitIcopRdy//cacop_working && cacop.isICache
    inv_wrbk := my_hit_inv || my_idx_inv
    wrt_back := send_miss && cached || inv_wrbk
    replace  := wrt_back
    read_tag := my_hit_inv && s === sWaitArready
    read_tlb := False
    st_0_tag := False
    accept   := False
    // Next-state priority chain: init/cancel > L2 request > L2 response >
    // exception hand-off > new request > miss issue > CACOP operation.
    when(init.valid || resp.cancel){
        s := sIdle
        l2req := False
        data_ok := False
    }.elsewhen(wait_l2){
        s := sWaitL2
    }.elsewhen(s === sWaitL2 && l2tlb_willresp){
        s := Mux(is_hit_inv, sOperate, sLookup)
    }.elsewhen(s === sLookup && ex && excp_ready){
        s := sDied
    }.elsewhen(addr_ok){
        when(req_valid){
            accept := True
            st_0_tag := req_st_0_tag
            read_tag := req_read_tag
            read_tlb := req_read_tlb
            when(req_cacop.valid){
                when(req_st_0_tag){
                    s := sIdle
                }.otherwise{
                    s := sOperate
                }
            }.otherwise{
                s := sLookup
            }
        }.elsewhen(s =/= sIdle){
            s := sIdle
        }
    }.elsewhen(miss_valid && !miss.arready){
        s := sWaitArready
    }.elsewhen(miss_valid && miss.arready){
        when(cached || mem.ld){
            s := sMiss
        }.otherwise{
            s := sWrite
        }
    }.elsewhen(s === sOperate){
        when(is_hit_inv && !l1tlb_resp.hit && !l2sent){
            //sendL2req()
            s := sWaitL2
            l2req := True//Reg
            l2sent := l2tlb_ready
        }.elsewhen(is_hit_inv && ex){
            when(excp_ready){
                s := sIdle
                cacop_working := False
            }
        }.elsewhen(cacop.isDCache){
            when(miss.arready){
                s := sIdle
                cacop_working := False
            }.otherwise{
                s := sWaitArready
            }
        }.elsewhen(cacop.isICache){
            send_cop := !ex
            when(icop_rdy){
                s := sIdle
                cacop_working := False
            }.otherwise{
                s := sWaitIcopRdy
            }
        }.otherwise{
            s := sIdle
            cacop_working := False
        }
    }
    // load data returns through recv; store/uncached-write completion through data_ok
    recv.start := send_miss && mem.ld


    miss.arvalid := miss_valid &&(mem.ld || cached)
    // write request: uncached store, or write-back of a dirty victim line
    miss.awvalid := ( miss_valid && mem.st && !cached
                    || replace && dirty)
    miss.occupy  := s === sLookup || s === sWaitArready || is_my_op

    /** Latch lookup results; with `pr` set, exceptions and uncached accesses are skipped. */
    def collect(info:DLookupInfo, cc:Bool):Unit = {
        ex     := info.ex && !info.mem.pr
        skip   := (info.ex || !cc) && info.mem.pr
        cached := cc
        mem    := info.mem
        cacop  := info.cacop
    }
    /** Request handshake: accept only when addr_ok and not initialising. */
    def collect(req:DataReq):Unit = {
        req_valid := req.valid
        req_cacop := req.cacop
        req.allow := addr_ok && !init.valid
    }
    /** Forward I-cache CACOPs over the icop port. */
    def collect(icop:CacopBus):Unit = {
        icop_rdy := icop.ready
        icop.valid := send_cop
    }
    /** Start a line fill for cached misses and observe its completion. */
    def collect(fill:CacheFiller, iscached:Bool):Unit = {
        fill_done := fill.done
        fill.start := send_miss && iscached
    }
}

/** Data cache top level.
  *
  * Wires the tag/valid RAMs (`tagv`), data banks (`data`), LRU state (`repl`),
  * line filler, bus receiver, L1/L2 TLB ports and the I-cache CACOP port
  * around the DCacheCtrl FSM, and keeps a few performance counters.
  */
class DCache extends Component{
    val io = new Bundle{
        val dreq  :DataReq  = slave (new DataReq )
        val dexcpt:DExcept  = master(new DExcept )
        val dresp :DataResp = master(new DataResp)
        val dsrch = master(Stream(new AddrTransReq(LISA.TLB.DATA)))
        val dsrchresp = in(new AddrTransResp)
        val l2srch = master(Stream(new L2TLBSrchData))
        val l2srchresp = slave(Flow(new L2TLBResp))
        val l2will_resp = in(Bool())
        val dmiss :DMiss    = master(new DMiss   )
        val icop  :CacopBus = master(new CacopBus)
        //val pa  :UInt = out (LISA.PA)//difftest need
    }
    val cfg  = DCacheConfig
    // Number of VA bits that select a way in index-type CACOPs
    // (same value as log2Up(cfg.n_way); 1 for the 2-way configuration).
    val way_bits: Int = BigInt(cfg.n_way - 1).bitLength
    val tagv = cfg.getTagV
    val data = cfg.getDDat
    val info = new DLookupInfo
    val repl = new LRUReplace
    val recv = new BusReceiver(io.dresp,io.dmiss)
    val ctrl = new DCacheCtrl(io.dmiss,io.dresp,recv)
    val (hits,hit) = info.tagCompare(tagv, io.dsrchresp.hit || (io.l2srchresp.valid && !ctrl.ex))
    ctrl.l1tlb_resp.assignSomeByName(io.dsrchresp)
    ctrl.l2tlb_valid := io.l2srchresp.valid
    //val hits = Bits()
    /*
    when(io.l2srchresp.valid){
        hits := io.l2srchresp.tagCompare(tagv)
    }.otherwise{
        hits := io.dsrchresp.tagCompare(tagv)
    }
    val hit = hits.orR
    when(io.l2srchresp.valid){
        ctrl.cached := io.l2srchresp.mat === B(1)
    }.elsewhen(io.dsrchresp.hit){
        ctrl.cached := io.dsrchresp.mat === U(1)
    }.otherwise{
        ctrl.cached := info.cached
    }
    */
    // Set when a request that needs translation (normal access or hit-type
    // CACOP) is presented; cleared once translation is reported settled with
    // no request pending, or on cancel.
    val need_translate = RegInit(False)
    need_translate.riseWhen(io.dreq.valid && (!io.dreq.cacop.valid || io.dreq.cacop.isHitInv))
    need_translate.fallWhen(io.dreq.translate_ok && !io.dreq.valid || io.dresp.cancel)
    io.dreq.translate_ok := io.dsrchresp.hit || io.l2srchresp.valid || !need_translate || ctrl.skip || ctrl.s === ctrl.CacheState.sIdle
    val buf_hits = Reg(hits)
    val buf_wpos = RegNextWhen(info.wpos,ctrl.read_dat)
    val buf_tagv = Reg(tagv(0).io.rd)
    val rtag = cfg.getVec1D(tagv(_).io.rd)
    val is_dirty:Bool = data.map(_.io.rb).reduce(_||_)
    val dat_line:Vec[UInt] = Vec(data.map(_.io.rd))
    val rval = dat_line(buf_wpos)
    val fill = new CacheFiller(io.dmiss)
    val cc = Bool()
    val savedtlb = new SavedTlbres

    ctrl.l2tlb_willresp := io.l2will_resp
    ctrl.l2tlb_ready := io.l2srch.ready
    // only a valid victim (top bit of the buffered tag/valid word) can be dirty
    ctrl.dirty := is_dirty && buf_tagv.msb
    io.dsrch.va    := io.dreq.va
    io.dsrch.st    := io.dreq.mem.st
    io.dsrch.valid := ctrl.read_tlb
    ctrl.collect(io.dreq)
    ctrl.collect(info,cc)
    ctrl.collect(fill,cc)
    ctrl.excp_ready := io.dexcpt.allow

    ctrl.collect(io.icop)
    io.icop.code := info.cacop.code
    io.icop.ptag := info.ptag
    io.icop.vidx := info.vidx

    ctrl.hit := hit
    info.va := io.dreq.va
    // Translation fields default to the saved copy and are overridden by a
    // fresh L1 response (one cycle after the search) or an L2 response.
    info.assignSomeByName(savedtlb)
    //info.pt := U(0)
    //info.mat := U(0)
    //info.ex := False
    //info.ecode := U(0)
    //info.esubc := U(0)
    val l1resp_valid = RegNext(io.dsrch.valid)
    when(l1resp_valid){
        info.assignSomeByName(io.dsrchresp)
    }.elsewhen(io.l2srchresp.valid){
        //info.assignSomeByName(io.l2srchresp)
        info.pt := io.l2srchresp.ptag
        info.ecode := io.l2srchresp.ecode
        info.esubc := io.l2srchresp.esubc
        info.ex := io.l2srchresp.ex
        info.mat := io.l2srchresp.mat.asUInt
    }

    // Register the accepted request's fields.
    when(ctrl.accept && !io.l2will_resp){
        info.vidx := info.nextVIdx
        info.mem  := io.dreq.mem
        info.wdata  := io.dreq.wdata
        info.wstrb  := io.dreq.wstrb
        info.cacop  := io.dreq.cacop
    }


    // A fresh translation is available: L1 hit response or any L2 response.
    val update_saved = l1resp_valid && io.dsrchresp.hit || io.l2srchresp.valid

    when(update_saved){
        savedtlb.pt := Mux(io.l2srchresp.valid, io.l2srchresp.ptag, io.dsrchresp.pt)
        savedtlb.mat := Mux(io.l2srchresp.valid, io.l2srchresp.mat.asUInt, io.dsrchresp.mat)
        savedtlb.ex := Mux(io.l2srchresp.valid, io.l2srchresp.ex, io.dsrchresp.ex)
        savedtlb.ecode := Mux(io.l2srchresp.valid, io.l2srchresp.ecode, io.dsrchresp.ecode)
        savedtlb.esubc := Mux(io.l2srchresp.valid, io.l2srchresp.esubc, io.dsrchresp.esubc)

    }.elsewhen(ctrl.s === ctrl.CacheState.sIdle || ctrl.s === ctrl.CacheState.sDied){
        // only the exception flag is cleared between requests
        savedtlb.ex := False
    }
    cc := Mux(update_saved, info.cached, savedtlb.cc)
    info.setFields(io.dmiss, update_saved, savedtlb)
    io.dexcpt.ex   := ctrl.send_ex
    info.send_ex(io.dexcpt)
    io.dresp.data  := Mux(recv.valid,recv.data,rval)
    // Operands of the L2 TLB search, captured whenever io.dreq.valid is high.
    // NOTE(review): this also samples requests that are not accepted
    // (io.dreq.allow low); if the pipeline can present a different VA while
    // an L2 search is pending, the search would use that VA — confirm whether
    // capturing on ctrl.accept was intended.
    val l2req_va = RegNextWhen(io.dreq.va,io.dreq.valid)
    val l2req_st = RegNextWhen(io.dreq.mem.st,io.dreq.valid)

    io.l2srch.valid := False
    io.l2srch.vppn := l2req_va.takeHigh(LISA.TLB.VPPN_LEN+1).asUInt
    io.l2srch.st := l2req_st
    when(ctrl.l2req){
        io.l2srch.valid := True
        //io.l2srch.vppn := l2req_va.takeHigh(LISA.TLB.VPPN_LEN+1).asUInt
    }

    // Bus payload. Uncached accesses (other than index-invalidate CACOPs,
    // which always work on a whole line) send one beat of the request's own
    // size; otherwise whole lines move with bus_size (MemType.size encoding,
    // 2 = word) and the write address is rebuilt from the buffered victim tag.
    val bus_size = U(2, 3 bits)
    when(!cc && !(info.cacop.isIdxInv && info.cacop.valid)){
        io.dmiss.arsize := info.mem.size
        io.dmiss.awsize := info.mem.size
        io.dmiss.awaddr := io.dmiss.araddr
        io.dmiss.wdata(0) := info.wdata
        for(i <- 1 until io.dmiss.wdata.length){
            io.dmiss.wdata(i) := U(0)
        }
        io.dmiss.wstrb  := info.wstrb
    }.otherwise{
        io.dmiss.arsize  := bus_size
        io.dmiss.awaddr  := buf_tagv(0, LISA.w_ppn bits) @@ info.widx @@ U(0, 4 bits)
        io.dmiss.awsize  := bus_size
        io.dmiss.wdata   := dat_line
        io.dmiss.wstrb.setAll()
    }
    // tag RAM address: init sweep, write index (fill / invalidate-write-back /
    // pending L2 response), otherwise the read index
    val tagv_a = UInt(8 bits)
    // data RAM address: bits 7..0 = index, bit 8 = way
    val data_a = UInt(9 bits)
    when(ctrl.init.valid){
        tagv_a := ctrl.init.idx
    }.elsewhen(fill.valid || ctrl.inv_wrbk || ctrl.l2tlb_willresp){
        tagv_a := info.widx
    }.otherwise{
        tagv_a := info.ridx
    }
    val repl_mask = Mux(ctrl.inv_wrbk, B(0), cfg.getWayMask(repl.way))
    val hinv_mask = Mux(ctrl.my_hit_inv,hits,B(0))
    val iinv_mask = Mux(ctrl.my_idx_inv,cfg.getWayMask(info.vidx(0, way_bits bits)),B(0))
    val wrbk_mask = hinv_mask | iinv_mask | repl_mask
    val sel_tagv  = MuxOH.or((repl_mask | iinv_mask).asBools,rtag)
    val hit_tagv  = Mux(ctrl.my_hit_inv,hit.asUInt @@ info.ptag,U(0))
    val wrbk_tagv = sel_tagv | hit_tagv
    val fill_wpos = fill.getPos(info)
    val fill_wmsk = Range(0,cfg.n_bank).map(fill_wpos === U(_))
    val bank_mask = Range(0,cfg.n_bank).map(info.wpos === U(_))
    // StoreTag CACOP way select, one entry per way: the low way_bits of the VA
    // choose the way, matching iinv_mask. (Previously built over n_bank from
    // va(w_pos-1..0), so with 2 ways VA[1:0] = 2 or 3 wrote no way at all.)
    val st_0_mask = Range(0,cfg.n_way).map((i:Int) => info.va(0, way_bits bits) === U(i))
    val fill_tagv = fill.isWriteTagV(io.dmiss)
    val fill_data = fill.isWriteData(io.dmiss)
    //val tlbresp:Bool = RegNext(io.dsrch.valid)
    //val saved_tag: UInt = RegNextWhen(info.pt, ctrl.send_miss) init(0)
    // tag write data: valid bit + saved page tag on fill, all-zero otherwise
    val tagv_wd:UInt = Mux(fill_tagv,U(1,1.bits)@@savedtlb.pt,U(0))
    //val tagv_wd:UInt = Mux(fill_tagv,info.tagv,U(0))
    // merge store bytes when no fill is running or on fill beat fill.num == 0
    val data_w_sel:Bool = info.mem.st && (fill.num === U(0) || !fill.valid)
    val data_ws:Bits = Mux(data_w_sel, info.wstrb, B(0))
    val data_wm:Bits = Mux(fill.valid,B"1111",info.wstrb)
    // per byte (MSB first): store byte if its strobe is selected, else bus data
    val data_wd:UInt = Range(3,-1,-1).map((i:Int)=>
        Mux(data_ws(i),info.wdata(i*8, 8 bits),io.dmiss.rdata(i*8, 8 bits))
    ).reduce(_ @@ _)
    cfg.each_way(tagv(_).io.wd := tagv_wd)
    cfg.each_way(tagv(_).io.a  := tagv_a )
    cfg.each_blk(data(_).io.wd := data_wd)
    cfg.each_blk(data(_).io.wm := data_wm)
    cfg.each_blk(data(_).io.wb := data_w_sel)
    cfg.each_blk(data(_).io.a  := data_a )
    // tag RAM command, in priority order
    when(ctrl.init.valid){
        cfg.each_way(tagv(_).io.setWrite())
    }.elsewhen(fill_tagv){
        cfg.each_way((i:Int)=>tagv(i).io.setWriteWhen(repl_mask(i)))
    }.elsewhen(ctrl.inv_wrbk){
        cfg.each_way((i:Int)=>tagv(i).io.setWriteWhen((wrbk_mask(i)) & io.dmiss.arready))
    }.elsewhen(ctrl.read_tag){
        cfg.each_way((i:Int)=>tagv(i).io.setRead())
    }.elsewhen(ctrl.st_0_tag){
        cfg.each_way((i:Int)=>tagv(i).io.setWriteWhen(st_0_mask(i)))
    }.otherwise{
        cfg.each_way(tagv(_).io.setIdle())
    }
    // data RAM command, in priority order; bit 8 of data_a picks the way
    data_a(0, 8 bits) := info.widx
    when(fill_data){
        data_a(8, 1 bits) := OHToUInt(repl_mask)
        cfg.each_blk((i:Int)=>
            data(i).io.setWriteWhen(fill_wmsk(i))
        )
    }.elsewhen(ctrl.update){
        data_a(8, 1 bits) := OHToUInt(hits)
        cfg.each_blk((i:Int)=>
            data(i).io.setWriteWhen(bank_mask(i))
        )
    }.elsewhen(ctrl.read_dat){
        data_a(8, 1 bits) := OHToUInt(hits)
        cfg.each_blk((i:Int)=>
            data(i).io.setReadWhen(bank_mask(i))
        )
    }.elsewhen(ctrl.wrt_back){
        data_a(8, 1 bits) := OHToUInt(wrbk_mask)
        cfg.each_blk((i:Int)=>
            data(i).io.setReadWhen(True)
        )
    }.otherwise{
        data_a(8, 1 bits) := U(0)
        cfg.each_blk((i:Int)=>
            data(i).io.setIdle()
        )
    }
    when(ctrl.read_dat){
        buf_hits := hits
    }.elsewhen(ctrl.wrt_back){
        buf_tagv := wrbk_tagv
        buf_hits := wrbk_mask
    }
    // LRU: read with the tags, cleared during init, updated on response handshake
    when(ctrl.read_tag){
        repl.read(info.ridx)
    }
    when(ctrl.init.valid){
        repl.write(ctrl.init.idx,False)
    }.elsewhen(io.dresp.valid && io.dresp.allow){
        repl.write(info.widx,repl.update(fill.valid,buf_hits))
    }

/*==========================Performer Counter================================*/
    // accepted requests
    val pm_dcache_req = UInt(32 bits) setAsReg() init(0)
    when(io.dreq.valid && io.dreq.allow){
        pm_dcache_req := pm_dcache_req + U(1)
    }

    // +1 per cycle spent in sLookup with an uncached, non-excepting access
    val pm_dcache_uncache_req = UInt(32 bits) setAsReg() init(0)
    when(ctrl.s === ctrl.CacheState.sLookup && !ctrl.cached && !ctrl.ex){
        pm_dcache_uncache_req := pm_dcache_uncache_req + U(1)
    }

    // +1 per cycle in sLookup where a cached, non-excepting access issues a miss
    val pm_dcache_miss = UInt(32 bits) setAsReg() init(0)
    when(ctrl.s === ctrl.CacheState.sLookup && ctrl.cached && ctrl.miss_valid && !ctrl.ex){
        pm_dcache_miss := pm_dcache_miss + U(1)
    }
/*==========================Performer Counter================================*/
}